%% SU-MIMO channel estimation example
% This code illustrates SU-MIMO channel estimation.
% It can be used to generate the data for Fig. 7.
% Data generation with a large number of Monte Carlo runs can take an
% excessive amount of time; please use "Fig7x.m" to generate the curve from the provided data.

clear all
close all
clc

% ---------------------------------------------------------------------
% Scenario parameters
% ---------------------------------------------------------------------
L = 2;              % number of paths
aspread_aoa = 45;   % AoAs are drawn as integer angles in [-aspread_aoa+1, aspread_aoa]
no_sim = 10;        % Monte Carlo runs per SNR point
T_s = 1;            % number of soundings
M = 128;            % grid size for AoD

fade = 2;
% fade - nature of fade.
%       0 is pure LoS (unit-modulus path gains with uniformly random phase)
%       1 is Rayleigh fading
%       2 is arbitrary fading (random-sign, truncated-Gaussian real/imag parts)

dist_min = 0.5; % lower truncation point for the truncated Gaussian fading (used when fade == 2)
% Example:
%     pd = makedist('Normal',0,sqrt(0.5));  trunc_pd = truncate(pd,dist_min,inf);
%     draws values in [dist_min, inf). Multiplying each draw by a random sign gives
%     real and imaginary parts of the path gains (say, \alpha) that satisfy
%     |Re(\alpha)|, |Im(\alpha)| >= dist_min, i.e. no path is arbitrarily weak.

OSR = 4;                % spatial oversampling ratio
N_t = 32;               % number of antennas at UE
N_r = (128/OSR)*OSR;    % number of antennas at BS (= 128)
d_r = 0.5/OSR;          % inter-element spacing at the BS
d_t = 0.5;              % inter-element spacing at the UE

SNR_dB = -16:4:0;           % SNR points (dB)
SNR = 10.^(0.1.*SNR_dB);    % linear SNR

aoa_grid = -70:1:70;    % search grid (degrees) for AoA estimation
spacing_aoa = 20;       % minimum spacing between AoAs (degrees)

% codebook parameters
u_grid = linspace(-1,1,M);
% k-th grid point corresponds to asind(-1 + (2/(M-1))*(k-1)) degrees
spacing_aod = 20;       % minimum spacing between AoDs, measured in u_grid indices

edge_skip = 3;
% edge_skip = 3 => considered grid points: 4 to M-3
% with M = 128 the AoDs then lie in approximately -72 to 72 degrees
l_eff = M - 2*edge_skip;   % number of admissible AoD grid points

T_1 = 10;   % number of snapshots for AoA estimation

% The minimum-spacing angle generators in the simulation loop draw L values
% from n-(L-1)*(b-1) candidates (n = size of the range, b = minimum spacing),
% so that count must be at least L. Fail early with a readable message.
assert(2*aspread_aoa - (L-1)*(spacing_aoa-1) >= L, ...
    'AoA parameters incompatible: enlarge aspread_aoa or reduce L/spacing_aoa.');
assert(l_eff - (L-1)*(spacing_aod-1) >= L, ...
    'AoD parameters incompatible: enlarge M or reduce L/spacing_aod/edge_skip.');

tic

MSE_uq_channel = zeros(1, length(SNR));
MSE_uqf_channel = zeros(1, length(SNR));
MSE_sd_channel = zeros(1, length(SNR));
MSE_ar_channel = zeros(1, length(SNR));

% Precoder for AoA estimation: only the first UE antenna transmits (isotropic transmission).
w_opt = [1; zeros(N_t-1,1)];
U = tril(ones(N_r));
% Closed-form inverse of U: the lower-triangular all-ones matrix inverts to the
% first-difference matrix (1 on the diagonal, -1 on the first subdiagonal).
% This is exact and avoids an explicit matrix inversion.
U_inv = eye(N_r) - diag(ones(N_r-1,1), -1);

pd = makedist('Normal',0,sqrt(0.5));
trunc_pd = truncate(pd,dist_min,inf);

% Pilot length for AR is selected to match the overall channel estimation
% overhead of the proposed method, but it is never shorter than N_t.
% ceil() keeps it an integer even when M is not a power of two, because it is
% used as a matrix dimension (randn, dftmtx) in the simulation loop.
S_len_ar = max(ceil(T_1 + 2*L*T_s*log2(M)), N_t);

Itermax = 1; % we use Itermax = 10 in the paper
% Maximum number of iterations of AR algorithm. Increasing this value leads to slower runtime

% Quantities that do not change across SNR points or Monte Carlo runs.
S_1 = dftmtx(S_len_ar);              % DFT matrix; its first N_t rows form the AR pilots
% BS steering vectors over the whole AoA search grid, one column per grid angle
% (assumes gen_a returns one steering vector per column, as relied on for A_RX below).
A_grid = gen_a(N_r, d_r, aoa_grid);

for iter = 1:length(SNR)

    % Per-SNR accumulators: squared Frobenius errors of each estimator and the
    % total true-channel energy, used to form the NMSE after the Monte Carlo loop.
    err_uq_channel = 0;
    err_uqf_channel = 0;
    err_sd_channel = 0;
    err_ar_channel = 0;
    ch_dr = 0;

    % Sigma-delta voltage level (c) selection: c = 3*sqrt(P/2) with P = SNR+1
    % for Steps 1-2 and P = SNR*N_t+1 for Step 3.
    % NOTE(review): P presumably is the expected per-antenna received power
    % (signal plus unit noise); confirm against the corresponding equation in the paper.
    lev_s1 = 3*sqrt((SNR(iter)+1)/2);      % Steps 1, 2
    lev_s2 = 3*sqrt((SNR(iter)*N_t+1)/2);  % Step 3

    R_n = eye(N_r) + ((2*lev_s1^2)/3).*U_inv*U_inv';  % noise covariance of the sigma-delta output
    R_n_pw = inv(sqrtm(R_n));                          % pre-whitening matrix R_n^(-1/2)

    % SNR-dependent pilots (identical for every Monte Carlo run).
    S = w_opt*sqrt(SNR(iter)).*ones(1,T_1);     % Step 1: T_1 snapshots from the first UE antenna
    % Step 2 pilot gains seen by each path. NOTE(review): this equals A_TX'*S only
    % if the first entry of every gen_a_v2 steering vector is 1 - confirm.
    G1 = sqrt(SNR(iter))*ones(L,T_1);
    S_ar = sqrt(SNR(iter)/N_t).*S_1(1:N_t,:);   % AR pilot matrix

    for loop_inner = 1:no_sim

        % ---------------- channel realization ----------------

        % path gain generation
        switch fade
            case 0  % LoS: unit-modulus gains with phase uniform in [-pi, pi)
                path_gain_angle = 2*pi*rand(L,1) - pi;
                alpha = ones(L,1).*exp(1j.*path_gain_angle);
            case 1  % Rayleigh: CN(0,1) gains
                alpha = sqrt(1/2).*( randn(L,1) + (1j).*randn(L,1));
            case 2  % arbitrary: random signs times truncated-Gaussian magnitudes
                s_set_r = 2.*(rand(L,1)>0.5)-1;
                s_set_i = 2.*(rand(L,1)>0.5)-1;
                a_r = s_set_r.*random(trunc_pd,L,1);
                a_i = s_set_i.*random(trunc_pd,L,1);
                alpha = a_r + (1j).*a_i;
            otherwise
                error('Unsupported fade value %g (expected 0, 1 or 2).', fade);
        end

        % AoA generation: L sorted integer angles in [-aspread_aoa+1, aspread_aoa]
        % with pairwise spacing >= spacing_aoa. Draw from a shrunken range, then
        % stretch by the required gaps (the result is already sorted).
        aoa_draw = sort(randperm(2*aspread_aoa - (L-1)*(spacing_aoa-1), L));
        doa_true = aoa_draw + (0:L-1)*(spacing_aoa-1) - aspread_aoa;

        % AoD generation: L sorted grid indices in [edge_skip+1, M-edge_skip]
        % with pairwise index spacing >= spacing_aod.
        aod_draw = sort(randperm(l_eff - (L-1)*(spacing_aod-1), L));
        dod_true_dc_ind = edge_skip + aod_draw + (0:L-1)*(spacing_aod-1);
        dod_true_dc = u_grid(dod_true_dc_ind);   % direction cosines
        dod_true = asind(dod_true_dc);           % AoDs in degrees (kept for inspection)

        A_TX = gen_a_v2(N_t,d_t, dod_true_dc);
        A_RX = gen_a(N_r,d_r, doa_true);
        H_true = sqrt(1/L).*A_RX*diag(alpha)*A_TX';
        ch_dr = ch_dr + norm(H_true,'fro')^2;

        % ---------------- data transmission (Step 1) ----------------
        N_noise = sqrt(1/2).*(randn(N_r,T_1) + (1j).*randn(N_r,T_1));
        X = H_true*S + N_noise; % received unquantized signal

        % ---------------- AoA estimation (Step 1) ----------------
        psi_1 = 0; % angle steered to broadside for AoA estimation
        Y = sigma_delta_ADC(X, psi_1, d_r, lev_s1);

        R_x = (1/T_1).*X*X';
        R_y = (1/T_1).*Y*Y';

        % Bartlett beamformer spectrum |a(theta)' R a(theta)| over the AoA grid,
        % evaluated for all grid angles at once (column-wise quadratic forms).
        cost_a = abs(sum(conj(A_grid).*(R_x*A_grid), 1));
        cost_sd = abs(sum(conj(A_grid).*(R_y*A_grid), 1));

        % UQ: peaks of the unquantized spectrum
        [~, pos] = local_max(cost_a,L);
        doa_est_a = sort(aoa_grid(pos));
        A_RX_est_a = gen_a(N_r, d_r, doa_est_a);   % estimated array manifold

        % SD: peaks of the sigma-delta spectrum
        [~, pos] = local_max(cost_sd,L);
        doa_est_sd = sort(aoa_grid(pos));
        A_RX_est_sd = gen_a(N_r, d_r, doa_est_sd);

        % ---------------- path gain estimation (Step 2) ----------------
        % Least-squares fit of vec(X) = D*alpha with D built from the Khatri-Rao
        % product of the pilot gains and the estimated receive manifold.

        % UQ
        D = sqrt(1/L).*krp( transpose(G1),A_RX_est_a  );
        alpha_est_a = D\reshape(X, T_1*N_r,1);

        % Sigma-Delta: same fit after pre-whitening the shaped noise
        D_sd = sqrt(1/L).*krp( transpose(G1),R_n_pw*A_RX_est_sd  );
        alpha_est_sd = D_sd\reshape(R_n_pw*Y, T_1*N_r,1);

        % ---------------- AoD estimation (Step 3) ----------------

        % Estimated AoD grid indices
        dod_est_a_dc_ind =  Dod_est_UQ(H_true, SNR(iter), u_grid, T_s,  A_RX_est_a, L, d_t);
        dod_est_sd_dc_ind = Dod_est_SD(H_true, SNR(iter), u_grid, T_s, doa_est_sd, L, d_t, lev_s2, d_r);

        % Estimated array manifolds
        A_TX_est_a = gen_a_v2(N_t, d_t, u_grid(dod_est_a_dc_ind));
        A_TX_est_sd = gen_a_v2(N_t, d_t, u_grid(dod_est_sd_dc_ind));

        % Estimated channels
        H_est_a = sqrt(1/L).*A_RX_est_a*diag(alpha_est_a)*A_TX_est_a';
        H_est_sd = sqrt(1/L).*A_RX_est_sd*diag(alpha_est_sd)*A_TX_est_sd';

        err_uq_channel = err_uq_channel + norm(H_true - H_est_a, 'fro')^2;
        err_sd_channel = err_sd_channel + norm(H_true - H_est_sd, 'fro')^2;

        % ---------------- AR based channel estimation ----------------
        N_noise_ar = sqrt(1/2).*(randn(N_r,S_len_ar) + (1j).*randn(N_r,S_len_ar));
        X_ar = H_true*S_ar + N_noise_ar; % received signal at BS

        Y_q = sgn(X_ar);          % 1-bit quantized signal
        R = norm(H_true,'fro');   % genie-aided channel norm passed to the AR estimator

        [H_est_ar, doa_est_ar, dod_est_ar, alpha_est_ar] = ar_mod_codebook(Y_q, S_ar,...
            R, L,1,d_r, M, d_t, Itermax);

        err_ar_channel = err_ar_channel + norm(H_est_ar - H_true,'fro')^2;

        % ---------------- UQ with full data ----------------
        [Doa_est_uqf, Dod_est_uqf, alpha_est_uqf] = SU_MIMO_ch_est_bf_codebook(X_ar,d_r,L,S_ar,1,M,d_t);
        % NOTE(review): the TX manifold here is built with gen_a, whereas every other
        % TX manifold in this script uses gen_a_v2 on direction cosines. Confirm
        % the units of Dod_est_uqf match what gen_a expects.
        H_est_uqf = sqrt(1/L).*gen_a(N_r, d_r, Doa_est_uqf)*diag(alpha_est_uqf)*gen_a(N_t, d_t, Dod_est_uqf)';

        err_uqf_channel = err_uqf_channel + norm(H_est_uqf - H_true,'fro')^2;

        fprintf('SNR point %d/%d, run %d/%d\n', iter, length(SNR), loop_inner, no_sim);

    end

    % Normalized MSE: accumulated squared error over accumulated channel energy.
    MSE_uq_channel(iter) = err_uq_channel/ch_dr;
    MSE_uqf_channel(iter) = err_uqf_channel/ch_dr;
    MSE_sd_channel(iter) = err_sd_channel/ch_dr;
    MSE_ar_channel(iter) = err_ar_channel/ch_dr;

end


toc

%% Plotting results

% NMSE (dB) versus SNR for every estimator, drawn in the same order as the
% legend entries below; hold is enabled after each curve so they overlay.
nmse_curves = {MSE_uq_channel, MSE_uqf_channel, MSE_sd_channel, MSE_ar_channel};
curve_styles = {'r-d', 'm-d', 'b', 'k-s'};

figure
for curve_idx = 1:numel(nmse_curves)
    nmse_db = 10*log10(nmse_curves{curve_idx});
    plot(SNR_dB, nmse_db, curve_styles{curve_idx}, 'LineWidth', 1.2)
    hold on
end
grid on
xlabel('SNR (dB)')
ylabel(' NMSE (dB) ')
legend('UQ','UQ full data','SD','AR')

% file_save = ['L_',num2str(L),'aspread_',num2str(aspread_aoa),'.mat']
% save(file_save)
